import torchvision.models as models
from torch.profiler import profile, record_function, ProfilerActivity

import torch


def _profiler_activities(use_cuda):
    """Return the profiler activities to record: CPU always, plus CUDA when requested."""
    activities = [ProfilerActivity.CPU]
    if use_cuda:
        activities.append(ProfilerActivity.CUDA)
    return activities


def _profile_cpu_time(model, inputs):
    """Profile one CPU forward pass and print the ten most expensive operators."""
    with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
        with record_function("model_inference"):
            model(inputs)

    # print results
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

    # finer granularity: group the statistics additionally by operator input shapes
    print(prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=10))


def _profile_gpu_time(model, inputs):
    """Profile one forward pass with CPU and CUDA activities; print operators sorted by CUDA time.

    Callers must only use this when CUDA is available and *model*/*inputs* live on the GPU.
    """
    with profile(activities=_profiler_activities(True), record_shapes=True) as prof:
        with record_function("model_inference"):
            model(inputs)

    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


def _profile_memory(model, inputs):
    """Profile one CPU forward pass with memory tracking and print tables sorted by memory usage."""
    print("### Memory Profiling ###")
    with profile(activities=[ProfilerActivity.CPU],
                 profile_memory=True, record_shapes=True) as prof:
        model(inputs)

    print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
    print(prof.key_averages().table(sort_by="cpu_memory_usage", row_limit=10))


def _export_trace(model, inputs, activities, path="trace.json"):
    """Profile one forward pass and export it as a Chrome trace file at *path*."""
    with profile(activities=activities) as prof:
        model(inputs)

    prof.export_chrome_trace(path)


def _profile_stacks(model, inputs, activities, sort_key):
    """Profile one forward pass with stack capture; print stats grouped by the top 5 stack frames."""
    with profile(activities=activities, with_stack=True) as prof:
        model(inputs)

    # Print aggregated stats
    print(prof.key_averages(group_by_stack_n=5).table(sort_by=sort_key, row_limit=2))


def _profile_with_schedule(model, inputs, activities, sort_key, steps=8):
    """Run *steps* forward passes under a wait=1 / warmup=1 / active=2 profiler schedule.

    Whenever the profiler reports a trace as ready, the handler prints an operator
    table and exports a ``trace_<step_num>.json`` Chrome trace.
    """
    print("### profiling scheduler")

    def trace_handler(p):
        print(p.key_averages().table(sort_by=sort_key, row_limit=10))
        trace_path = f"trace_{p.step_num}.json"
        print(trace_path)
        p.export_chrome_trace(trace_path)

    with profile(activities=activities,
                 schedule=torch.profiler.schedule(wait=1, warmup=1, active=2),
                 on_trace_ready=trace_handler) as p:
        for _ in range(steps):
            model(inputs)  # here goes the train loop
            p.step()  # tell the profiler one iteration finished so the schedule advances


def example():
    """Walk through ``torch.profiler`` features on a ResNet-18 forward pass.

    Prints operator tables for CPU time, CUDA time (only when CUDA is available),
    memory usage and stack-grouped stats. Writes ``trace.json`` plus the
    ``trace_<step>.json`` files of the scheduled run into the current directory.
    Runs on CPU-only machines by recording CPU activity only.

    Returns:
        0, usable as a process exit status.
    """
    # setup cuda
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    print(f"Using {device} device")

    activities = _profiler_activities(use_cuda)
    # CUDA timings only exist when CUDA activity is recorded; fall back to CPU self time.
    self_time_key = "self_cuda_time_total" if use_cuda else "self_cpu_time_total"

    # CPU-only demos always use a model and inputs that live on the CPU.
    cpu_model = models.resnet18()
    cpu_inputs = torch.randn(5, 3, 224, 224)
    _profile_cpu_time(cpu_model, cpu_inputs)

    # Device-placed copy for the remaining demos; .to(device) avoids crashing without CUDA.
    model = models.resnet18().to(device)
    inputs = torch.randn(5, 3, 224, 224, device=device)

    # measure time on GPU
    if use_cuda:
        _profile_gpu_time(model, inputs)
    else:
        print("CUDA not available; skipping GPU timing")

    # measure memory consumption
    _profile_memory(cpu_model, cpu_inputs)

    # Trace Profiling
    _export_trace(model, inputs, activities)

    # Examine Stack traces
    _profile_stacks(model, inputs, activities, self_time_key)

    # Schedule profiler (previously unreachable behind a stray exit(0))
    _profile_with_schedule(model, inputs, activities, self_time_key)

    return 0


if __name__ == "__main__":
    # Script entry point: run the full profiling walkthrough.
    example()
